iT邦幫忙

2023 iThome 鐵人賽

0

簡介

最近 Hugging Face Transformers 整合了 Flash Attention 2,可以減少記憶體消耗並提昇模型運算的速度,且使用方式非常簡單,來分享一下這個用法。

Flash Attention

Flash Attention 為去年五月 Stanford University 提出的論文,作者設計了一個 IO-Aware 的演算法,根據裝置的 IO 速度來最佳化 Attention 的運算,並將 Softmax 運算拆解開來,以減少 GPU 記憶體的消耗。在今年七月作者又發表了一篇 Flash Attention 2 的論文,進一步提昇了 Flash Attention 的速度。

在 Hugging Face Text Generation Inference (TGI) 裡面,很早就整合了 Flash Attention 的技術,一直到兩週前 HF Transformers 才完成 Flash Attention 的整合。在 HF Transformers 裡面調用 Flash Attention 2 相當簡單,只要加上 use_flash_attention_2 的參數即可:

from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained(
    "TheBloke/Llama-2-7b-chat-fp16",
    device_map="auto",
    use_flash_attention_2=False,
)

model 實際印出來,可以看到 Attention Layer 變成 Flash Attention 2 的版本:

...
(0-31): 32 x LlamaDecoderLayer(
    (self_attn): LlamaFlashAttention2(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
    )
    ...
)
...

透過以下程式碼來粗略估計推論時會用到的記憶體:

import torch
from transformers import LlamaForCausalLM as ModelCls

model: ModelCls = ModelCls.from_pretrained(
    "TheBloke/Llama-2-7b-chat-fp16",
    device_map="auto",
    load_in_8bit=True,
    use_flash_attention_2=True,
)

unit_gib = 1024 ** 3

curr_mem = torch.cuda.memory_reserved() / unit_gib
print(f"Initial: {curr_mem:.4f} GiB")

batch_size = 2
seq_len = 3072
inn = torch.LongTensor([[0] * seq_len] * batch_size)

try:
    with torch.no_grad():
        out = model(inn)

    infer_mem = torch.cuda.memory_reserved() / unit_gib
    print(f"Inference: {infer_mem:.4f} GiB")
except:
    print(f"Inference: OOM")

將模型量化為 8-Bit 時,權重本身約佔用 7 GiB。當我們使用 Flash Attention 2 對長度 3K 的輸入進行推論時,約需要消耗 13.1 GiB。若把 Flash Attention 2 關掉的話,則要消耗 16.5 GiB。當長度越長,消耗的記憶體差距越大,將不同 Batch Size 與 Sequence Length 的關係畫成線圖如下:

Flash Attention 2

實線為使用 Flash Attention 2,而虛線則沒有使用,可以看到 Flash Attention 2 的記憶體消耗呈現線性關係,而原本的 Attention 則是平方成長上去。

目前實測起來,記憶體部份似乎只有推論階段受益於 Flash Attention 機制,訓練階段似乎沒有變化。速度部份也許有變化,但筆者尚未完成這個部份的測試。

參考


上一篇
LLM Note Day 30 - 學海無涯,學無止境
下一篇
LLM Note Day 32 - AutoGPTQ
系列文
LLM 學習筆記33
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言